import concurrent.futures
import numpy as np
import cv2
import glob

def shift_img(img, x, y):
    """Translate ``img`` by ``x`` pixels horizontally and ``y`` pixels vertically.

    Builds a pure-translation 2x3 affine matrix and warps the image with it.
    The output keeps the input's width and height, so content moved past an
    edge is cropped. Fractional (sub-pixel) shifts are accepted.
    """
    height, width = img.shape[:2]
    translation = np.array([[1, 0, x], [0, 1, y]], dtype=np.float32)
    # cv2 expects the destination size as (width, height).
    return cv2.warpAffine(img, translation, (width, height))

def binary_diff_mask(clean, dirty, thresold=0.3, *, threshold=None):
    """Return a 0/1 mask of where two images differ after a gamma power.

    Both inputs are raised element-wise to the power 1/2.2 and compared by
    absolute difference. Every element whose difference exceeds the threshold
    becomes 1; all others become 0. The comparison is per element, so a
    multi-channel image yields a per-channel mask of the same shape.

    Args:
        clean: Reference image as a numeric array. The callers in this file
            pass images divided by 255.
        dirty: Image to compare against ``clean``; same (or broadcastable) shape.
        thresold: Difference threshold. The misspelled name is kept so that
            existing callers using ``thresold=`` keep working.
        threshold: Correctly spelled keyword-only alias. When given, it
            overrides ``thresold``.

    Returns:
        ``np.uint8`` array of 0s and 1s.
    """
    if threshold is not None:
        thresold = threshold

    # NOTE(review): x ** (1/2.2) is gamma *encoding*. sRGB values are not
    # linear in light intensity, and linearising them would use an exponent
    # of 2.2. The exponent is kept as-is so existing thresholds keep their
    # meaning; confirm the intended direction.
    clean_g = np.power(clean, 1 / 2.2)
    dirty_g = np.power(dirty, 1 / 2.2)

    diff = np.abs(clean_g - dirty_g)
    return (diff > thresold).astype(np.uint8)

# Pair clean/dirty frames by position after sorting both directory listings.
# NOTE(review): assumes clean/ and dirty/ contain correspondingly named files,
# so that sorted order lines the pairs up. Confirm.
clean = sorted(glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png"))
dirty = sorted(glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png"))

# Position (in sorted order) of the image pair analysed by this script.
IMAGE_INDEX = 34
if min(len(clean), len(dirty)) <= IMAGE_INDEX:
    # The glob patterns are relative to the working directory, so a short or
    # empty listing may mean the script was started from the wrong place.
    raise FileNotFoundError(
        f"Need at least {IMAGE_INDEX + 1} clean and dirty images; "
        f"found {len(clean)} clean and {len(dirty)} dirty"
    )

clean_img = cv2.imread(clean[IMAGE_INDEX])
dirty_img = cv2.imread(dirty[IMAGE_INDEX])
# cv2.imread returns None instead of raising when a file cannot be decoded.
if clean_img is None or dirty_img is None:
    raise FileNotFoundError(
        f"cv2.imread could not read {clean[IMAGE_INDEX]} and/or {dirty[IMAGE_INDEX]}"
    )

# Unshifted reference mask.
# NOTE(review): process() computes its own mask and does not read this value.
binary_diff_mask_img = binary_diff_mask(clean_img / 255, dirty_img / 255, thresold=0.05)

# Half-width of the shift search. process() maps loop indices 0..2k-1 to
# shifts of (index - k) / 4 px, i.e. -5.0 to 4.75 px in 0.25 px steps.
k = 20

def process(i, j):
    """Shift the clean image by one sub-pixel offset, diff it, and save the mask.

    Loop indices ``i`` and ``j`` map to an x/y shift of ``(index - k) / 4``
    pixels. The 0/1 mask from ``binary_diff_mask`` is scaled to 0/255 and
    written under ``test/``. A failed write is reported on stdout.
    """
    print(i)
    dx = (i - k) / 4
    dy = (j - k) / 4

    shifted_clean = shift_img(clean_img, dx, dy)
    # `thresold` (sic) is binary_diff_mask's parameter name.
    mask = binary_diff_mask(shifted_clean / 255, dirty_img / 255, thresold=0.3)

    # Build the path once so the error message reports the file actually written.
    out_path = f"test/test_img_x{dx}-y{dy}.png"
    # cv2.imwrite signals failure through its return value
    # (e.g. when the test/ directory does not exist).
    if not cv2.imwrite(out_path, mask * 255):
        print(f"Failed to save image at {out_path}")


# Guard so that worker processes started with the "spawn" start method, which
# re-import this module, do not submit the whole job grid again themselves.
if __name__ == "__main__":
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(process, i, j)
            for i in range(k * 2)
            for j in range(k * 2)
        ]
        # Future.result() re-raises any exception from the worker. Without
        # it, failures inside process() would be silently discarded.
        for future in concurrent.futures.as_completed(futures):
            future.result()
